HW3 Michal Grotkowski

I trained two models, Random Forest and Logistic Regression, on the Steel Plates Fault dataset, which consists of 13 numerical features and a binary target. I chose two instances of each class and compare their LIME explanations below.

Instances of class 0

The comparison below covers two scenarios: one where the models predicted different classes, and one where they predicted the same class. In the first scenario the LIME explanations show that both models use the values of features V20 and V23 to shift the prediction towards class 0. There is a discrepancy for V25: Random Forest uses its value to shift the prediction towards the true class 0, while Logistic Regression does the opposite.
For the second instance, the Random Forest explanation is consistent with the first one for feature V19, whose value falls in the same range.
This time both models predict the same class, and the LIME explanations suggest that both rely on several feature values to do so, among them V17, V23, V25, V26 and V27.
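Such side-by-side comparisons can also be made programmatically: each LIME explanation object exposes as_list(), which returns (condition, weight) pairs whose sign indicates the direction of the shift. A minimal sketch of pairing up the shared conditions (the two example lists below are hypothetical illustrative values, not the actual explanations from this notebook):

```python
# Hypothetical as_list() outputs for the two models (illustrative values only).
rf_exp = [("V20 <= 0.95", -0.11), ("V23 > 0.78", -0.08), ("V25 <= 0.82", -0.05)]
lr_exp = [("V20 <= 0.95", -0.09), ("V23 > 0.78", -0.06), ("V25 <= 0.82", 0.04)]

def compare_explanations(exp_a, exp_b):
    """Pair up conditions shared by both explanations and report both weights."""
    weights_b = dict(exp_b)
    rows = []
    for condition, weight_a in exp_a:
        if condition in weights_b:
            rows.append((condition, weight_a, weights_b[condition]))
    return rows

for condition, w_rf, w_lr in compare_explanations(rf_exp, lr_exp):
    # Weights with the same sign shift the prediction in the same direction.
    agree = "agree" if (w_rf < 0) == (w_lr < 0) else "DISAGREE"
    print(f"{condition:12s} RF={w_rf:+.2f} LR={w_lr:+.2f} -> {agree}")
```

This makes discrepancies such as the V25 one above easy to spot without reading the plots.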

In [68]:
print("Instance number %s. Class : %s \n" % (instance, targ))
print("Random Forest explanation:\n")
explainer.explain_instance(inst, predict_fn = forest_model.predict_proba).show_in_notebook()
print("Logistic Regression explanation:\n")
explainer.explain_instance(inst, predict_fn = log_reg.predict_proba).show_in_notebook()
Instance number 100. Class : 0 

Random Forest explanation:

Logistic Regression explanation:

In [70]:
print("Instance number %s. Class : %s \n" % (instance, targ))
print("Random Forest explanation:\n")
explainer.explain_instance(inst, predict_fn = forest_model.predict_proba).show_in_notebook()
print("Logistic Regression explanation:\n")
explainer.explain_instance(inst, predict_fn = log_reg.predict_proba).show_in_notebook()
Instance number 250. Class : 0 

Random Forest explanation:

Logistic Regression explanation:

Instances of class 1

This is the underrepresented class in the dataset, and the figures below depict two instances for which the two models predict different classes.
For the first instance, Logistic Regression predicts the correct class. Here Random Forest uses the value of feature V23 to shift the prediction towards class 0, while Logistic Regression uses the same value to shift it towards the correct class 1.
For the second instance the situation is reversed: Random Forest predicts the correct class. The LIME explanations are interesting here, as the two models overlap in the features whose values shift the prediction towards the correct class. However, feature V16 carries no visible importance for Logistic Regression, while for Random Forest its value shifts the prediction towards the correct class.
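LIME labels each contribution with a discretised condition string such as "V23 <= 0.78" rather than the bare feature name, so detecting sign flips like the V23 case above takes a little parsing. A hedged sketch, assuming all feature names match the V-number pattern used in this dataset (the weights below are made up for illustration):

```python
import re

def feature_of(condition):
    """Extract the feature name from a LIME condition string,
    e.g. '0.78 < V23 <= 0.90' or 'V16 <= 0.24' -> 'V23' / 'V16'."""
    match = re.search(r"V\d+", condition)
    return match.group(0) if match else condition

def sign_flips(exp_a, exp_b):
    """Return features whose contributions point in opposite directions."""
    a = {feature_of(c): w for c, w in exp_a}
    b = {feature_of(c): w for c, w in exp_b}
    return sorted(f for f in a.keys() & b.keys() if (a[f] < 0) != (b[f] < 0))

# Illustrative values only, not the actual LIME outputs from this notebook.
rf = [("V23 <= 0.78", -0.07), ("V16 > 0.30", 0.05)]
lr = [("V23 <= 0.78", 0.06), ("V16 > 0.30", 0.04)]
print(sign_flips(rf, lr))  # -> ['V23']
```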

In [72]:
print("Instance number %s. Class : %s \n" % (instance, targ))
print("Random Forest explanation:\n")
explainer.explain_instance(inst, predict_fn = forest_model.predict_proba).show_in_notebook()
print("Logistic Regression explanation:\n")
explainer.explain_instance(inst, predict_fn = log_reg.predict_proba).show_in_notebook()
Instance number 1940. Class : 1 

Random Forest explanation:

Logistic Regression explanation:

In [74]:
print("Instance number %s. Class : %s \n" % (instance, targ))
print("Random Forest explanation:\n")
explainer.explain_instance(inst, predict_fn = forest_model.predict_proba).show_in_notebook()
print("Logistic Regression explanation:\n")
explainer.explain_instance(inst, predict_fn = log_reg.predict_proba).show_in_notebook()
Instance number 1401. Class : 1 

Random Forest explanation:

Logistic Regression explanation:

Appendix

In [1]:
!git clone https://github.com/adrianstando/imbalanced-benchmarking-set.git
!pip install lime
Cloning into 'imbalanced-benchmarking-set'...
Successfully installed lime-0.2.0.1
In [16]:
import pandas as pd
import lime
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
In [43]:
df = pd.read_csv("./imbalanced-benchmarking-set/datasets/steel-plates-fault.csv")
# Binarise the target: keep the dummy column for the second class label
target = pd.get_dummies(df["TARGET"], drop_first = True)[2]
df = df.iloc[:,1:-1]  # keep only the feature columns
X = pd.get_dummies(df, drop_first = True)
X.head(), target.head()
Out[43]:
(      V15     V16     V17     V18     V19     V20  V21     V22     V23  \
 0  0.0498  0.2415  0.1818  0.0047  0.4706  1.0000  1.0  2.4265  0.9031   
 1  0.7647  0.3793  0.2069  0.0036  0.6000  0.9667  1.0  2.0334  0.7782   
 2  0.9710  0.3426  0.3333  0.0037  0.7500  0.9474  1.0  1.8513  0.7782   
 3  0.7287  0.4413  0.1556  0.0052  0.5385  1.0000  1.0  2.2455  0.8451   
 4  0.0695  0.4486  0.0662  0.0126  0.2833  0.9885  1.0  3.3818  1.2305   
 
       V24     V25     V26     V27  
 0  1.6435  0.8182 -0.2913  0.5822  
 1  1.4624  0.7931 -0.1756  0.2984  
 2  1.2553  0.6667 -0.1228  0.2150  
 3  1.6532  0.8444 -0.1568  0.5212  
 4  2.4099  0.9338 -0.1992  1.0000  ,
 0    0
 1    0
 2    0
 3    0
 4    0
 Name: 2, dtype: uint8)
In [44]:
forest_model = RandomForestClassifier(n_estimators = 100, max_depth = 10)
forest_model.fit(X, target)

log_reg = make_pipeline(StandardScaler(), LogisticRegression(max_iter = 1000))
log_reg.fit(X, target)


print(accuracy_score(target, forest_model.predict(X)), balanced_accuracy_score(target, forest_model.predict(X)))
print(accuracy_score(target, log_reg.predict(X)), balanced_accuracy_score(target, log_reg.predict(X)))
0.8907779495105616 0.8452852475614158
0.6759402369912416 0.5759183654337423
In [47]:
from lime.lime_tabular import LimeTabularExplainer

explainer = LimeTabularExplainer(X.values, feature_names = X.columns, training_labels = target)
In [57]:
import numpy as np
print(np.where(target == 1))
(array([1268, 1269, 1270, ..., 1938, 1939, 1940]),)
In [73]:
instance = 1401
inst, targ = X.iloc[instance], target.iloc[instance]
In [64]:
import warnings

# Silence the sklearn "X does not have valid feature names" warning
# triggered by passing a pandas row to the fitted pipeline (see below).
def warn(*args, **kwargs):
    pass
warnings.warn = warn
In [61]:
explainer.explain_instance(inst, predict_fn = log_reg.predict_proba).show_in_notebook()
/usr/local/lib/python3.10/dist-packages/sklearn/base.py:439: UserWarning: X does not have valid feature names, but StandardScaler was fitted with feature names
  warnings.warn(